package utils

import (
	"context"
	"errors"
	"fmt"
	"github.com/alomerry/home/core/extension"
	"github.com/golang-jwt/jwt/v4"
	"time"
)

// Lifetimes used when issuing portal access tokens.
const (
	// AccessTokenPortalExpireTime is the lifetime GeneratePortalAccessToken
	// passes to genAccessToken when signing a portal token.
	AccessTokenPortalExpireTime = 12 * time.Hour
	// SessionValidPeriodTime is the TTL of the Redis session entry written by
	// GeneratePortalAccessToken for each issued token.
	SessionValidPeriodTime = 12 * time.Hour
)

// Identifiers and key formats for portal tokens.
const (
	// TokenUserTypeUser is the user-type prefix placed in the "sub" claim; it is
	// the only user type genAccessToken accepts.
	TokenUserTypeUser string = "user"
	// PortalTokenFormat is the fmt pattern for the Redis session key; the raw
	// token string is substituted for %s.
	PortalTokenFormat string = "portal:%s"
)

// GeneratePortalAccessToken issues a signed access token for account and
// registers it in Redis as an active portal session.
//
// The session entry is keyed by PortalTokenFormat with the token substituted,
// holds an empty value, and expires after SessionValidPeriodTime. Any error
// from loading the secret, signing, or writing to Redis is returned as-is
// alongside an empty token.
func GeneratePortalAccessToken(ctx context.Context, account string) (string, error) {
	secret, err := getSecretKey(ctx)
	if err != nil {
		return "", err
	}

	signed, err := genAccessToken(ctx, account, secret, TokenUserTypeUser, AccessTokenPortalExpireTime)
	if err != nil {
		return "", err
	}

	// Register the session; ParsePortalAccessToken later requires this key.
	sessionKey := fmt.Sprintf(PortalTokenFormat, signed)
	if _, err := extension.GetRDB().SetEx(ctx, sessionKey, "", SessionValidPeriodTime).Result(); err != nil {
		return "", err
	}

	return signed, nil
}

// ParsePortalAccessToken verifies accessToken and checks that it is still
// registered as an active portal session in Redis.
//
// It returns the parsed token on success. It returns an error if
// parseAccessToken rejects the token, if the token's claims are not
// jwt.MapClaims, or if validateToken fails (session key missing or Redis error).
func ParsePortalAccessToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
	token, err := parseAccessToken(ctx, accessToken)
	if err != nil {
		return nil, err
	}

	// jwt.Parse decodes into MapClaims by default. Use the checked assertion so
	// an unexpected claims type becomes an error rather than a panic.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("unexpected claims type")
	}

	if err := validateToken(ctx, claims, accessToken); err != nil {
		return nil, err
	}
	return token, nil
}

// genAccessToken builds and signs an HS256 JWT for account using secret.
//
// Only TokenUserTypeUser is supported; any other userType returns an error
// ("不支持" = "not supported"). The token carries:
//   - iat: issue time, Unix seconds
//   - exp: iat + expireTime, Unix seconds
//   - iss / aud: the fixed value "water"
//   - sub: "<userType>:<account>"
//
// ctx is currently unused.
func genAccessToken(ctx context.Context, account, secret, userType string, expireTime time.Duration) (string, error) {
	if userType != TokenUserTypeUser {
		return "", errors.New("不支持")
	}

	now := time.Now()
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		// Issued-at time.
		"iat": now.Unix(),
		// Expiry must be Unix seconds. expireTime is a Duration (nanoseconds), so
		// adding int64(expireTime) to seconds would push exp centuries ahead.
		"exp": now.Add(expireTime).Unix(),
		// Issuer.
		"iss": "water",
		// Audience.
		"aud": "water",
		// Subject: the user this token is issued for.
		"sub": fmt.Sprintf("%s:%s", userType, account),
	})

	// SignedString returns ("", err) on failure, so its result is passed through.
	return token.SignedString([]byte(secret))
}

// parseAccessToken parses accessToken and verifies it with jwt.Parse. The key
// comes from getKeyLookup.
//
// On failure it returns an error prefixed with "parse failed". When jwt.Parse
// itself fails, that error wraps the underlying jwt error, so callers can
// inspect it with errors.Is / errors.As instead of losing the cause.
func parseAccessToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
	token, err := jwt.Parse(accessToken, getKeyLookup(ctx))
	if err != nil {
		return nil, fmt.Errorf("parse failed: %w", err)
	}
	// jwt.Parse reports invalid tokens through err. Checking Valid as well
	// guarantees an unverified token can never be returned with a nil error.
	if !token.Valid {
		return nil, errors.New("parse failed")
	}
	return token, nil
}

// getKeyLookup returns a jwt.Keyfunc that supplies the HMAC secret used to
// verify portal tokens.
//
// The returned function first rejects any token whose alg header is not HS256,
// the only algorithm genAccessToken signs with. The token is untrusted input,
// so it must not be allowed to choose its own verification algorithm. It then
// loads the secret via getSecretKey and returns it as []byte.
func getKeyLookup(ctx context.Context) jwt.Keyfunc {
	return func(t *jwt.Token) (interface{}, error) {
		if t.Method == nil || t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
			return nil, fmt.Errorf("unexpected signing method %v", t.Header["alg"])
		}
		secret, err := getSecretKey(ctx)
		if err != nil {
			return nil, err
		}
		return []byte(secret), nil
	}
}

// getSecretKey loads the HMAC signing secret for portal tokens from the Redis
// key "jwt:portal:sk".
//
// It returns an error if the Redis lookup fails. The error wraps the Redis
// error with %w, so errors.Is still matches it. NOTE(review): this assumes
// GetRDB is a go-redis client, where a missing key surfaces as redis.Nil
// (confirm).
//
// It also returns an error if the stored secret is empty: an empty HMAC key
// would let anyone forge tokens, so it is refused rather than used.
func getSecretKey(ctx context.Context) (string, error) {
	const key = "jwt:portal:sk"

	secret, err := extension.GetRDB().Get(ctx, key).Result()
	if err != nil {
		return "", fmt.Errorf("reading jwt secret %q: %w", key, err)
	}
	if secret == "" {
		return "", fmt.Errorf("jwt secret %q is empty", key)
	}
	return secret, nil
}

// validateToken checks that accessToken still has a session entry in Redis,
// under the key built from PortalTokenFormat.
//
// It returns the Redis error if the lookup fails, an error "not exists" if the
// key is absent, and nil otherwise. claims is accepted but not currently used.
func validateToken(ctx context.Context, claims jwt.MapClaims, accessToken string) error {
	sessionKey := fmt.Sprintf(PortalTokenFormat, accessToken)
	count, err := extension.GetRDB().Exists(ctx, sessionKey).Result()

	switch {
	case err != nil:
		return err
	case count == 0:
		return errors.New("not exists")
	default:
		return nil
	}
}
